# Copyright (c) 2023 Copyright holder of the paper "Revisiting Image Classifier Training for Improved Certified Robust Defense against Adversarial Patches" submitted to TMLR for review

# All rights reserved.

import argparse
import numpy as np
import math
import random
import time
import os

import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision
import timm

from build import generate_masks
from utils import *

# Example usage:
#   python train_advmask_gridsearch.py --dataset imagenet -k 3 -lr 0.01 --pretrained-model resnetv2 -c gridsearchadv_maskset3x3

pretrained_model_options = ["resnetv2", "vit_base", "convnext"]
datasets = ["imagenet", "cifar10", "cifar100", "imagenette", 'svhn']

parser = argparse.ArgumentParser(description='Worst Case Single Mask Training')
# NOTE(review): --workers is parsed but main() does not pass it to get_dataloaders.
parser.add_argument('-j', '--workers', default=6, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('-ep', '--epochs', default=10, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N')
parser.add_argument('-lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR')
parser.add_argument('-p', '--print-freq', default=1000, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--seed', default=2022, type=int,
                    help='seed for initializing training. ')
parser.add_argument('-c', '--checkpoint-name', type=str, help='checkpoint name.')
parser.add_argument('--pretrained-model', type=str, default="resnetv2",
                    choices=pretrained_model_options)
parser.add_argument('--optim', metavar='optim', default='sgd',
                    choices=["sgd", "adam"])
parser.add_argument('-r', '--resume', dest='resume', action='store_true',
                    help='resume model from last checkpoint')
parser.add_argument('--task', metavar='task', default='trainval',
                    choices=["trainval", "train", "val"])
parser.add_argument('--dataset', metavar='dataset', default='imagenet',
                    choices=datasets)
# Masks are laid out on a k x k grid, e.g. k=3 gives 3x3 = 9 mask locations.
parser.add_argument('-k', '--num-mask-locations', default=3, type=int,
                    metavar='N')
# NOTE(review): parsing at import time makes this module unimportable from other
# scripts/tests; main() parses again and only its result is used there.
args = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def girdsearch_worstcase_masks(clean_images, labels, classifier, masks_set_view):
    """Return, per sample, the image masked by its highest-loss mask in the set.

    Every mask in ``masks_set_view`` is tried: masked pixels are replaced by 0.5
    (presumably mid-grey for images in [0, 1] -- TODO confirm), the batch is
    scored with ``F.cross_entropy`` and, for each sample, the mask giving the
    largest loss is kept. No gradients are tracked, and the classifier's
    train/eval mode is left untouched.

    Masked batches are built and scored one mask at a time rather than
    materialising the full (B, num_masks, C, H, W) tensor, so peak memory does
    not grow with the size of the mask set.

    Args:
        clean_images: image batch; broadcast below as (B, C, H, W).
        labels: target class indices for ``F.cross_entropy``.
        classifier: model mapping normalized images to logits.
        masks_set_view: mask set with the set dimension at index 1 and a
            singleton channel dimension at index 2 (``main`` builds it as
            (1, M, 1, H, W)); 1 marks a masked pixel.

    Returns:
        Tensor shaped like ``clean_images`` with each sample's worst-case mask
        applied.
    """
    batch_size = clean_images.shape[0]
    mask_set_size = masks_set_view.shape[1]
    # Keep the loss table on the images' device so filling it does not force a
    # device-to-host copy (and synchronisation) after every forward pass.
    losses = torch.zeros(batch_size, mask_set_size, device=clean_images.device)
    with torch.no_grad():
        for idx in range(mask_set_size):
            mask = masks_set_view[:, idx]
            masked_images = (1 - mask) * clean_images + mask * 0.5
            output = classifier(normalize(masked_images))
            losses[:, idx] = F.cross_entropy(output, labels, reduction="none")
        worst_mask_ids = torch.argmax(losses, dim=1)
        # Broadcast the mask set over the batch (as the stacked tensor used to be)
        # and gather each sample's worst mask, then apply just that one.
        batch_masks = masks_set_view.expand(batch_size, *masks_set_view.shape[1:])
        row_ids = torch.arange(batch_size, device=worst_mask_ids.device)
        worst_masks = batch_masks[row_ids, worst_mask_ids.to(batch_masks.device)]
        worst_images = (1 - worst_masks) * clean_images + worst_masks * 0.5
    return worst_images


def search_worstcase_masks(clean_images, labels, classifier, masks_set):
    """Greedily apply two worst-case masks in sequence and return the result.

    Round 1 picks, per sample, the single mask from ``masks_set`` that maximises
    the classifier's cross-entropy on the clean images. Round 2 applies every
    mask again on top of that result and keeps the worst one again. Each round
    is exactly one :func:`girdsearch_worstcase_masks` call. No gradients are
    tracked.

    Args:
        clean_images: image batch, (B, C, H, W).
        labels: target class indices.
        classifier: model mapping normalized images to logits.
        masks_set: mask set with the set dimension at index 1 (a singleton
            channel dimension is inserted here); 1 marks a masked pixel.

    Returns:
        Tensor shaped like ``clean_images`` with the two greedily chosen masks
        applied.
    """
    masks_set_view = masks_set.unsqueeze(dim=2).to(clean_images.device)
    masked_images = clean_images
    with torch.no_grad():
        for _ in range(2):
            masked_images = girdsearch_worstcase_masks(masked_images, labels, classifier, masks_set_view)
    return masked_images


def validate(val_loader, classifier, masks_set):
    """Print top-1 accuracy on clean images and on worst-case masked images.

    For every validation batch, :func:`search_worstcase_masks` builds the
    two-mask worst-case version of each image; clean and masked images are then
    classified in one concatenated forward pass. Results and elapsed wall time
    are printed once at the end; nothing is returned.
    """
    classifier.eval()

    clean_meter = AverageMeter('Clean Acc@1', ':6.2f')
    masked_meter = AverageMeter('Adv Acc@1', ':6.2f')
    started = time.time()

    with torch.no_grad():
        for batch in val_loader:
            images, labels = batch[0].cuda(), batch[1].cuda()
            batch_len = labels.size(0)

            worst_images = search_worstcase_masks(images, labels, classifier, masks_set)
            stacked = torch.cat((images, worst_images), dim=0)

            # Single forward pass over [clean; masked], split back into halves.
            logits = classifier(normalize(stacked))
            clean_logits, masked_logits = torch.split(logits, len(labels))

            for meter, preds in ((clean_meter, clean_logits), (masked_meter, masked_logits)):
                top1_acc, _top5_acc = accuracy(preds, labels, topk=(1, 5))
                meter.update(top1_acc[0], batch_len)

        elapsed = time.time() - started
        print(
            'Clean Acc@1 {top1.avg:.3f} Adv Acc@1 {top1_adv.avg:.3f}  Time {Time:.3f} secs'.format(
                top1=clean_meter, top1_adv=masked_meter, Time=elapsed))


def validate_clean(val_loader, classifier):
    """Print top-1 accuracy of ``classifier`` on unmasked validation images.

    The classifier is switched to eval mode and run without gradients. Accuracy
    and elapsed wall time are printed once at the end; nothing is returned.
    """
    classifier.eval()
    meter = AverageMeter('Clean Acc@1', ':6.2f')
    started = time.time()
    with torch.no_grad():
        for batch in val_loader:
            images = batch[0].cuda()
            labels = batch[1].cuda()
            logits = classifier(normalize(images))
            top1_acc, _top5_acc = accuracy(logits, labels, topk=(1, 5))
            meter.update(top1_acc[0], labels.size(0))
        print('Clean Acc@1 {top1.avg:.3f} Time {Time:.3f} secs'.format(
            top1=meter, Time=time.time() - started))


def _save_sample_images(images, iteration, fig_dir="figures", max_images=10):
    """Write up to ``max_images`` images of a batch to ``<fig_dir>/iter_<iteration>/<k>.png``."""
    iter_fig_path = os.path.join(fig_dir, f"iter_{iteration}")
    # makedirs also creates ``fig_dir`` itself; a plain mkdir failed when it was missing.
    os.makedirs(iter_fig_path, exist_ok=True)
    # Batches smaller than max_images (small -b, short last batch) must not raise IndexError.
    for img_idx in range(min(max_images, images.shape[0])):
        torchvision.utils.save_image(images[[img_idx]].cpu(), os.path.join(iter_fig_path, f"{img_idx}.png"))


def train(train_loader, classifier, optimizer, epoch, masks_set_view, args):
    """Train ``classifier`` for one epoch on worst-case masked images.

    For every batch, :func:`girdsearch_worstcase_masks` picks (without
    gradients) the mask in ``masks_set_view`` that maximises each sample's
    loss, and the optimizer step is taken on those masked images only.

    At ``epoch == args.epochs - 5`` every param group's learning rate is set to
    one tenth of the first group's, so the final five epochs run at the decayed
    rate. Every ``args.print_freq`` iterations progress is printed and up to
    ten masked training images are written under ``figures/iter_<i>/``.
    """
    epoch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    # Only worst-case masked images are classified during training, so the
    # accuracy is reported under the "Adv" label (no clean accuracy is computed).
    top1_adv = AverageMeter('Adv Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [losses, top1_adv, epoch_time],
        prefix="Epoch: [{}]".format(epoch))

    classifier.train()

    epoch_start_time = time.time()

    if epoch == args.epochs - 5:
        lr = optimizer.param_groups[0]['lr']
        print(f"Changing the learning rate of Classifier from {lr} to {lr * 0.1:.4f}")
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr * 0.1

    for i, data in enumerate(train_loader):
        clean_images, labels = data[0].cuda(), data[1].cuda()

        optimizer.zero_grad()

        # Replace each image by its highest-loss masked version, then train on it.
        adv_cutout_images = girdsearch_worstcase_masks(clean_images, labels, classifier, masks_set_view)
        output = classifier(normalize(adv_cutout_images))
        loss = get_classification_loss(output, labels)

        loss.backward()
        optimizer.step()

        losses.update(loss.item(), labels.size(0))
        acc1, _acc5 = accuracy(output, labels, topk=(1, 5))
        top1_adv.update(acc1[0], labels.size(0))

        # Cumulative wall time since the start of this epoch.
        epoch_time.update(time.time() - epoch_start_time)

        if i % args.print_freq == 0:
            progress.display(i)
            _save_sample_images(adv_cutout_images, i)


def _build_optimizer(classifier, args):
    """Create the optimizer selected by ``args.optim`` over ``classifier``'s parameters.

    SGD uses momentum 0.9 and Adam uses betas (0.5, 0.999); both start at
    ``args.learning_rate``.
    """
    if args.optim == "sgd":
        return torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
    return torch.optim.Adam(classifier.parameters(), lr=args.learning_rate, betas=(0.5, 0.999))


def _build_two_mask_combinations(masks_set, img_size):
    """Merge every unordered pair (i <= j) of masks in ``masks_set`` into one mask.

    ``masks_set`` has the set dimension at index 1 and 1 marks a masked pixel.
    A pair's combined mask is 1 wherever either mask is 1 (i == j yields the
    single mask itself). Returns a float tensor of shape
    (1, M * (M + 1) // 2, img_size, img_size), ordered by i then j.
    """
    keep = 1 - masks_set  # 1 where a pixel is NOT masked
    mask_set_size = keep.shape[1]
    num_combinations = mask_set_size * (mask_set_size + 1) // 2
    combinations = torch.zeros(1, num_combinations, img_size, img_size)
    count = 0
    for mask1_idx in range(mask_set_size):
        for mask2_idx in range(mask1_idx, mask_set_size):
            # A pixel stays visible only if both masks keep it.
            combinations[:, count] = 1 - keep[:, mask1_idx] * keep[:, mask2_idx]
            count += 1
    return combinations


def main():
    """Build mask sets and the classifier, then train and/or evaluate.

    ``--task`` selects the flow:
      * ``train`` / ``trainval``: train for ``args.epochs`` epochs on the
        worst-case member of all two-mask combinations, checkpointing after
        every epoch (``trainval`` also prints clean accuracy each epoch), then
        save a final checkpoint without optimizer and run :func:`validate`.
      * ``val``: load the checkpoint and run :func:`validate` only.
    """
    cudnn.benchmark = True
    args = parser.parse_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    checkpoint_path = set_checkpoint_path(args.checkpoint_name, args.dataset, args.pretrained_model)

    print("Dataset:", args.dataset)
    print("Epochs:", args.epochs)
    print("Batchsize:", args.batch_size)
    print("Learning rate:", args.learning_rate)
    print("Random seed:", args.seed)
    print("Mask set size:", args.num_mask_locations**2)
    print("Pretrained model:", args.pretrained_model)
    print("Checkpoint path:", checkpoint_path)
    print("\n")

    train_loader, val_loader = get_dataloaders(args.dataset, args.batch_size)
    num_classes = get_num_classes(args.dataset)

    img_size = 224
    adv_patch_size = 39  # 39x39 constitutes 3% pixels of 224x224 image
    # k x k grid of masks; each mask is enlarged by (stride - 1) so neighbouring
    # masks overlap enough to fully cover a patch wherever it lands.
    stride = math.ceil((img_size - adv_patch_size + 1) / args.num_mask_locations)
    mask_size = adv_patch_size + stride - 1
    masks_set = generate_masks(mask_size, stride, (args.num_mask_locations ** 2), img_size)

    multisize_masks_set = [masks_set]

    masks_set_combinations = _build_two_mask_combinations(masks_set, img_size)
    masks_set_combinations_view = masks_set_combinations.unsqueeze(dim=2).to(device)
    print("Total two-mask unique combinations: ", masks_set_combinations.shape[1])

    timm_model_names = {
        "resnetv2": 'resnetv2_50x1_bit_distilled',
        "vit_base": 'vit_base_patch16_224',
        "convnext": 'convnext_tiny_in22ft1k',
    }
    classifier = timm.create_model(timm_model_names[args.pretrained_model], pretrained=True)

    if num_classes != 1000:
        classifier.reset_classifier(num_classes=num_classes)

    classifier.cuda()

    if args.task in ["trainval", "train"]:
        optimizer = _build_optimizer(classifier, args)
        if args.resume:
            saved_state = torch.load(checkpoint_path)
            classifier.load_state_dict(saved_state["state_dict"])
            # An optimizer object restored by torch.load references its own
            # unpickled copies of the parameters, so stepping it would never
            # update ``classifier``. Copy only its state into an optimizer built
            # over this classifier (assumes --optim matches the checkpoint).
            # The final checkpoint is saved with optimizer=None: start fresh then.
            saved_optimizer = saved_state["optimizer"]
            if saved_optimizer is not None:
                if isinstance(saved_optimizer, torch.optim.Optimizer):
                    saved_optimizer = saved_optimizer.state_dict()
                optimizer.load_state_dict(saved_optimizer)
            start_epoch = saved_state["epoch"] + 1
            print("Resume training. Loaded latest checkpoint and optimizer from epoch {}.".format(saved_state["epoch"]))
            # train() decays the LR x0.1 at the start of epoch args.epochs - 5;
            # when resuming after that point, restore the decayed value directly.
            lr = args.learning_rate if start_epoch <= args.epochs - 5 else args.learning_rate*0.1
            print("Setting the learning rate as ", lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        else:
            start_epoch = args.start_epoch
    elif args.task in ["val"]:
        saved_state = torch.load(checkpoint_path)
        classifier.load_state_dict(saved_state["state_dict"])

    if args.task in ["trainval", "train"]:
        # Keeps ``epoch`` bound even when the loop body never runs
        # (e.g. resuming a run that already finished).
        epoch = start_epoch - 1
        for epoch in range(start_epoch, args.epochs):
            train(train_loader, classifier, optimizer, epoch, masks_set_combinations_view, args)
            if "val" in args.task:
                validate_clean(val_loader, classifier)
            save_state_dict(epoch, optimizer, classifier, checkpoint_path)

        # The final checkpoint is stored without optimizer state.
        optimizer = None
        save_state_dict(epoch, optimizer, classifier, checkpoint_path)
        validate(val_loader, classifier, multisize_masks_set[0])

    elif args.task == "val":
        validate(val_loader, classifier, multisize_masks_set[0])
    else:
        raise RuntimeError("Invalid scenario!")


if __name__ == "__main__":
    # Run training/evaluation only when executed as a script.
    main()